import torch

from ppl_eval import measure_generation_latency
from transformers.models.llama import modeling_llama
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def get_model(model_dir, *, torch_dtype=torch.bfloat16):
    """Load a causal LM and its tokenizer from a local directory.

    While the model is being constructed, ``torch.nn.init.kaiming_uniform_``,
    ``uniform_`` and ``normal_`` are replaced with no-ops. Building the
    layers therefore skips random initialisation, which the checkpoint
    weights overwrite anyway. The original functions are restored afterwards,
    even if loading raises. Later code in the same process that creates new
    layers still gets properly initialised weights.

    Args:
        model_dir: Path to a local model directory. Nothing is downloaded
            (``local_files_only=True``).
        torch_dtype: dtype forwarded to ``from_pretrained``. Defaults to
            ``torch.bfloat16``, the previously hard-coded value. Callers that
            cast the model to fp16 right away may pass ``torch.float16``.
            If the checkpoint is stored in fp16, this avoids rounding the
            weights through bf16 first.

    Returns:
        ``(model, tokenizer)``. Missing bos/eos tokens on the tokenizer are
        filled with ``"<s>"``/``"</s>"``. Missing pad/unk tokens fall back to
        the eos token.
    """
    def skip(*args, **kwargs):
        pass

    init_names = ("kaiming_uniform_", "uniform_", "normal_")
    saved_inits = {name: getattr(torch.nn.init, name) for name in init_names}
    for name in init_names:
        setattr(torch.nn.init, name, skip)
    try:
        # NOTE(review): the config is loaded with trust_remote_code=False but
        # the model/tokenizer with trust_remote_code=True -- confirm the
        # mismatch is intentional.
        config = AutoConfig.from_pretrained(
            model_dir, trust_remote_code=False, local_files_only=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            config=config,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
            local_files_only=True,
        )
    finally:
        # Undo the global monkeypatch so it does not leak past this call.
        for name, fn in saved_inits.items():
            setattr(torch.nn.init, name, fn)

    tokenizer = AutoTokenizer.from_pretrained(
        model_dir,
        trust_remote_code=True,
        local_files_only=True,
        use_fast=False,
    )
    # Fill in special tokens the tokenizer config may not define.
    if tokenizer.bos_token is None:
        tokenizer.bos_token = "<s>"
    if tokenizer.eos_token is None:
        tokenizer.eos_token = "</s>"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.unk_token is None:
        tokenizer.unk_token = tokenizer.eos_token
    return model, tokenizer


def print_latency_report(
    tracker=None,
    keys=("input_layernorm", "self_attn", "post_attention_layernorm", "mlp", "total"),
):
    """Print accumulated per-component latency totals and per-call averages.

    Args:
        tracker: Mapping with a ``"count"`` entry plus one accumulated time per
            name in ``keys``. The report prints those totals in seconds.
            Defaults to ``modeling_llama.latency_tracker``, which is looked up
            at call time. ``count`` is presumably the number of timed calls;
            confirm against the patched modeling code.
        keys: Tracker entries to report, in order.

    Each line shows the total (``s``) and the average per count
    (``total / count * 1000``, printed as ``ms``). The average is ``0.0``
    when ``count`` is not positive, which avoids a division by zero.
    """
    print("[Latency Summary]")
    if tracker is None:
        tracker = modeling_llama.latency_tracker
    count = tracker["count"]
    for k in keys:
        total_time = tracker[k]
        avg_time = (total_time / count) * 1000 if count > 0 else 0.0
        print(f"{k:>24}: total = {total_time:.3f}s, avg = {avg_time:.3f}ms")


if __name__ == "__main__":
    # Load the local checkpoint and switch the model to eval mode.
    print('load model...')
    model, tokenizer = get_model("../llama2-13b-local")
    model.seqlen = 128
    model.eval()

    print('start evaluate...')
    # The peak counter is reset before the model is moved to the GPU, so the
    # reported peak also covers moving/casting the weights, not just generation.
    torch.cuda.reset_peak_memory_stats()
    gpu_model = model.cuda().half()
    prompt = "Once upon a time in a land far, far away,"
    # NOTE(review): the positional 512 / 32 arguments are interpreted by
    # ppl_eval.measure_generation_latency -- confirm their meaning there.
    measure_generation_latency(gpu_model, tokenizer, prompt, 512, 32, device="cuda")
    # Block until all queued CUDA work has finished before reading the peak.
    torch.cuda.synchronize()
    bytes_per_mib = 1024 * 1024
    peak_memory = torch.cuda.max_memory_allocated() / bytes_per_mib

    # layer_num / skipped_layer_count are counters exposed by the locally
    # patched modeling_llama module.
    total_layers = modeling_llama.layer_num
    print(f"[LOG] layer_num: {total_layers}")
    skipped_layers = modeling_llama.skipped_layer_count
    print(f"[LOG] Skipped layers: {skipped_layers}")
    if total_layers > 0:
        print(f"[LOG] skipped ratio: {skipped_layers/total_layers}")
    print(f"[LOG] Peak GPU memory: {peak_memory:.2f} MB")
    print_latency_report()